import numpy as np
import faiss

class RecallCompute(object):
    """Nearest-neighbour retrieval over a faiss flat L2 index, plus recall scoring.

    The gallery features are added to a ``faiss.IndexFlatL2`` at construction
    time; the query features are what ``recall_compute`` searches with.
    """

    def __init__(self, gallery_feats, query_feats, labels):
        """Store features/labels and build the L2 index over the gallery.

        Args:
            gallery_feats: 2-D array-like of gallery features (rows are
                vectors; the second dimension is used as the index dimension).
            query_feats: 2-D array-like of query features.
            labels: either one label collection shared by gallery and query,
                or a pair ``(gallery_labels, query_labels)`` when the
                "separate query" branch below is taken.
        """
        # NOTE(review): the separate-query mode is selected by ``query_feats``
        # being a 2-element list, which looks like a leftover from an older
        # signature taking feats = [gallery_feats, query_feats] and
        # labels = [gallery_labels, query_labels]. Kept as-is for
        # compatibility -- confirm intent against callers.
        if isinstance(query_feats, list) and len(query_feats) == 2:
            self.is_equal_query = False
            self.gallery_feats, self.query_feats = gallery_feats, query_feats
            self.gallery_labels, self.query_labels = labels
        else:
            self.is_equal_query = True
            self.gallery_feats = gallery_feats
            self.query_feats = query_feats
            self.gallery_labels = self.query_labels = labels

        # faiss operates on C-contiguous float32 matrices, so convert up front
        # instead of relying on the caller to have done it.
        self.test_features = np.ascontiguousarray(self.query_feats, dtype=np.float32)
        self.features = np.ascontiguousarray(self.gallery_feats, dtype=np.float32)
        self.num = self.test_features.shape[0]
        self.dim = self.features.shape[1]
        self.index = faiss.IndexFlatL2(self.dim)
        self.index.add(self.features)

    def recall_compute(self, k, image_labels, db, retrive_result=None):
        """Search the index, or score a previously computed search result.

        With ``retrive_result=None`` the method runs a ``k``-nearest-neighbour
        search for every query feature and returns the index matrix.

        Otherwise it scores ``retrive_result``: every row ``i`` with
        ``db[i] == 0`` is evaluated, and counts as a hit when more than one of
        its retrieved ids belongs to the same ``image_labels`` group as the
        query itself.

        Args:
            k: number of neighbours to retrieve (search mode only).
            image_labels: mapping ``group key -> collection of ids``. Ids are
                compared against ``row position + 1`` and
                ``retrieved index + 1``, i.e. they are offset by one from the
                index positions.
            db: per-row flags; only rows whose flag equals 0 are scored.
            retrive_result: per-query retrieved indices, or ``None`` to search.

        Returns:
            The retrieved index matrix in search mode; otherwise the fraction
            of evaluated rows that were hits (``0.0`` when no row is
            evaluated).

        Raises:
            KeyError: if a query id is found in no group and ``image_labels``
                has no key ``0`` (the fallback group key).
        """
        if retrive_result is None:
            _, retrive_result = self.index.search(self.test_features, k)
            return retrive_result

        hits = 0
        evaluated = 0
        for query_pos, neighbours in enumerate(retrive_result):
            if db[query_pos] != 0:
                continue
            evaluated += 1

            # Find the first group containing this query's id; falls back to
            # key 0 when none does (original behaviour).
            query_id = query_pos + 1
            group_key = 0
            for key in image_labels:
                if query_id in image_labels[key]:
                    group_key = key
                    break

            retrieved_ids = set(np.array(neighbours) + 1)
            shared = retrieved_ids.intersection(set(image_labels[group_key]))
            # More than one shared id is required -- presumably because the
            # query itself appears among its own neighbours (TODO confirm).
            if len(shared) > 1:
                hits += 1

        # Nothing to evaluate: report zero recall instead of dividing by zero.
        if evaluated == 0:
            return 0.0
        return float(hits) / evaluated

    def fusion_res(self, res1, res2):
        """Overwrite column 1 of ``res1`` with column 1 of ``res2``, in place.

        For each row where the two results disagree in column 1, ``res1``
        takes ``res2``'s entry. If that entry already occurs in ``res1`` at
        column 2 or later, its first such occurrence is replaced by the value
        that was displaced from column 1.

        Args:
            res1: 2-D index array; modified in place.
            res2: 2-D index array with at least as many rows as ``res1``.

        Returns:
            ``res1`` (the same object that was passed in).
        """
        for row, ranked in enumerate(res1):
            preferred = res2[row, 1]
            if res1[row, 1] == preferred:
                continue
            displaced = res1[row, 1]
            res1[row, 1] = preferred
            for col in range(2, len(ranked)):
                if res1[row, col] == preferred:
                    res1[row, col] = displaced
                    break
        return res1
